from abc import ABC, abstractmethod

from ase import Atoms

from AIMD import arguments
from AIMD.fragment import FragmentData
from AIMD.protein import Protein


class BaseFragment(ABC):
    r"""
    Abstract base for protein fragmentation strategies.

    Concrete subclasses must override :meth:`fragment`. Because that method is
    abstract, the class cannot be instantiated until a subclass provides it.
    """

    def __init__(self) -> None:
        # Nothing to set up at this level; subclasses add their own state.
        return None

    @abstractmethod
    def fragment(self, *args, **kwargs):
        r"""
        Split a protein into sub-structures (for instance dipeptides and
        ACE-NMEs). Every subclass supplies its own implementation according
        to its fragmentation scheme. The base version does nothing and
        returns None.
        """
        return None


class DipeptideFragment(BaseFragment):
    r"""
    A subclass of BaseFragment that generates dipeptides and ACE-NMEs.

    All dipeptide and ACE-NME atom indices come from a single routine,
    @method get_fragments_index. Subclasses implement @method get_fragments
    to apply their own postprocessing on top of those indices, and are
    expected to obtain the indices by calling @method get_fragments_index.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def get_fragments(self, prot: Protein) -> FragmentData:
        r"""
        Build the fragment data for ``prot``. Must be implemented by
        subclasses, presumably on top of @method get_fragments_index.
        """
        pass

    @staticmethod
    def get_fragments_index(
        prot: Atoms,
    ) -> tuple[list[list[int]], list[list[int]]]:
        r"""
        Fragment the protein into dipeptides/ACE-NMEs and return their indices
        in the original protein.

        The protein is read through ``prot.arrays["residuenumbers"]``
        (1-based, must be contiguous), ``prot.arrays["residuenames"]`` and
        ``prot.arrays["atomtypes"]``. A protein with N residues yields N-2
        dipeptides and N-3 ACE-NMEs. The code assumes residue 1 is the ACE
        cap and residue N is the NME cap.

        Atom assignment rules, for a residue r owning dipeptide r-2:
          * ACE atoms go to the first dipeptide; NME atoms go to the last one.
          * CA/HA* go to dipeptides r-3, r-2 and r-1 (when they exist) and to
            ACE-NMEs r-2 and r-3 (when they exist).
          * C/O go to dipeptides r-2 and r-1, and to ACE-NME r-2.
          * N/H go to dipeptides r-3 and r-2, and to ACE-NME r-3.
          * All other atoms (side chain) go only to dipeptide r-2.

        When ``solvent_method == 'AMOEBA'``, side-chain atoms are spliced into
        each dipeptide just before its second nitrogen, to follow Tinker's
        atom order. Otherwise they are appended in protein order.

        Parameters:
        -----------
            prot: Atoms
                The protein to be fragmented.

        Returns:
        --------
            dipeptides: list[list[int]]
                The index of dipeptides in the protein.
            acenmes: list[list[int]]
                The index of ACE-NMEs in the protein.

        Raises:
        -------
            AssertionError
                If residue numbers are not contiguous, or (AMOEBA only) if a
                dipeptide does not contain exactly two nitrogen atoms.
        """
        args = arguments.get()
        solvent_method = args.solvent_method

        residuenames = prot.arrays["residuenames"]
        residuenumbers = prot.arrays["residuenumbers"]
        atomtypes = prot.arrays["atomtypes"]

        # prot.arrays['residuenumbers'] starts from 1
        num_residue = max(residuenumbers)

        assert (
            len(set(residuenumbers)) == num_residue
        ), "residue numbers are not continuous"

        # One protein with N residues can be fragmented into N-2
        # dipeptides and N-3 ACE-NMEs.
        num_dipeptides = num_residue - 2
        num_acenmes = num_residue - 3

        dipeptides: list[list[int]] = [[] for _ in range(num_dipeptides)]
        sidechains: list[list[int]] = [[] for _ in range(num_dipeptides)]
        acenmes: list[list[int]] = [[] for _ in range(num_acenmes)]

        for index in range(len(prot)):
            resname = str(residuenames[index]).strip()
            atomtype = atomtypes[index]
            residue = residuenumbers[index]

            dipeptide_index = residue - 2
            ace_index = residue - 2
            nme_index = residue - 3
            has_prev_dipeptide = dipeptide_index > 0
            has_next_dipeptide = dipeptide_index < num_dipeptides - 1
            ace_in_range = 0 <= ace_index < num_acenmes
            nme_in_range = 0 <= nme_index < num_acenmes

            if resname == "ACE":
                # All head ACE atoms go to the first dipeptide.
                dipeptides[0].append(index)

            elif resname == "NME":
                # All tail NME atoms go to the last dipeptide.
                dipeptides[-1].append(index)

            elif atomtype == "CA" or atomtype[:2] == "HA":
                # CA, HA belong in the previous, own and next dipeptide, and
                # in both the ACE side and the NME side of ACE-NMEs.
                if has_prev_dipeptide:
                    dipeptides[dipeptide_index - 1].append(index)
                dipeptides[dipeptide_index].append(index)
                if has_next_dipeptide:
                    dipeptides[dipeptide_index + 1].append(index)

                if ace_in_range:
                    acenmes[ace_index].append(index)
                if nme_in_range:
                    acenmes[nme_index].append(index)

            elif atomtype in ("C", "O"):
                # C, O belong in the own and next dipeptide, and in the ACE
                # side of an ACE-NME.
                dipeptides[dipeptide_index].append(index)
                if has_next_dipeptide:
                    dipeptides[dipeptide_index + 1].append(index)

                if ace_in_range:
                    acenmes[ace_index].append(index)

            elif atomtype in ("N", "H"):
                # N, H belong in the previous and own dipeptide, and in the
                # NME side of an ACE-NME.
                if has_prev_dipeptide:
                    dipeptides[dipeptide_index - 1].append(index)
                dipeptides[dipeptide_index].append(index)

                if nme_in_range:
                    acenmes[nme_index].append(index)

            else:
                # Side-chain atoms belong only to their own dipeptide. For
                # AMOEBA they are held back and spliced in below to match
                # Tinker's atom order; otherwise they are appended directly
                # (previously they were silently dropped in this case).
                if solvent_method == 'AMOEBA':
                    sidechains[dipeptide_index].append(index)
                else:
                    dipeptides[dipeptide_index].append(index)

        # tinker: insert sidechain into backbone, just before the second 'N'
        if solvent_method == 'AMOEBA':
            for idx, unit in enumerate(dipeptides):
                nitrogens = [
                    pos for pos, atom in enumerate(unit) if atomtypes[atom] == 'N'
                ]
                assert len(nitrogens) == 2, "number of nitrogen atoms in dipeptide != 2"

                unit[nitrogens[1]:nitrogens[1]] = sidechains[idx]

        # print atom types of fragments
        if args.verbose >= 1:
            print(" [i] dipeptide fragments:")
            for idx, unit in enumerate(dipeptides):
                print(f"{idx:>8} | {' '.join([atomtypes[i] for i in unit])}")
            print(" [i] ACE-NME fragments:")
            for idx, unit in enumerate(acenmes):
                print(f"{idx:>8} | {' '.join([atomtypes[i] for i in unit])}")

        return dipeptides, acenmes
